import numpy as np
# Demo matrix: 3 rows by 4 columns of small integers.
a = np.array(
    [
        [1, 5, 5, 2],
        [9, 6, 2, 7],
        [3, 7, 9, 2],
    ]
)

# For every row (axis=1 reduces across columns), print the column index
# holding that row's maximum. Row 0 has a tie between columns 1 and 2;
# NumPy reports the first occurrence, so the output is [1 0 2].
print(a.argmax(axis=1))
